- Notifications
You must be signed in to change notification settings - Fork 251
/
Copy pathSoftmax_Sklearn.py
47 lines (36 loc) · 1.42 KB
/
Softmax_Sklearn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#-*- coding:utf-8 -*-
# &Author AnFany
import sklearn as sk
import numpy as np
from prettytable import PrettyTable  # formatted (table) output of the confusion matrix
from sklearn.linear_model import LogisticRegression

from Iris_Data import Data as smdata

# Softmax (multinomial) logistic regression fitted with the SAG solver.
# C is the inverse of the L2 regularisation strength, so C=200 regularises
# far less than the default C=1.0; max_iter is raised well above the
# default so the iterative solver has room to converge.
# NOTE(review): scikit-learn 1.5 deprecated the ``multi_class`` parameter;
# newer releases may warn about or reject it — confirm against the
# installed version before upgrading.
sklr = LogisticRegression(
    multi_class='multinomial',
    solver='sag',
    C=200,
    max_iter=10000,
)
def confusion(realy, outy, method='Sklearn'):
    """Build a PrettyTable confusion matrix for integer class labels.

    Args:
        realy: true labels, one row per sample; the label is read from
            element 0 of each row (e.g. an (n, 1) array as returned by
            ``transign``).
        outy: predicted labels, in the same row layout as ``realy``.
        method: text placed in the table's top-left header cell.

    Returns:
        A PrettyTable whose rows are true classes ('真实:%d类') and whose
        columns are predicted classes ('预测:%d类'), both in descending
        label order. Each cell counts the samples with that
        (true, predicted) pair. Labels are formatted with ``%d``, so they
        are expected to be integers.

    Raises:
        ValueError: if ``realy`` and ``outy`` have different lengths.
    """
    from collections import Counter

    if len(realy) != len(outy):
        raise ValueError('realy and outy must have the same number of rows')

    real_labels = [row[0] for row in realy]
    pred_labels = [row[0] for row in outy]

    # Use every label seen in the truth OR the predictions: a prediction of a
    # class absent from realy must still be counted, otherwise row totals
    # silently fall short of the true class sizes.
    labels = sorted(set(real_labels) | set(pred_labels), reverse=True)

    # One pass over the samples instead of one full scan per matrix cell.
    pair_counts = Counter(zip(real_labels, pred_labels))

    mix = PrettyTable()
    mix.field_names = [method] + ['预测:%d类' % label for label in labels]
    for real in labels:
        mix.add_row(
            ['真实:%d类' % real] + [pair_counts[(real, pred)] for pred in labels]
        )
    return mix
# Convert one-hot encoded class rows into class labels 1, 2, 3, ...
def transign(eydata):
    """Turn one-hot rows into a column of 1-based class labels.

    For each row, the label is the position of its first element equal to 1,
    plus one. The result is a numpy array with one single-element row per
    input row (shape (n, 1) for non-empty input). A row containing no
    element equal to 1 raises ValueError (from ``list.index``).
    """
    labels = [[list(one_hot).index(1) + 1] for one_hot in eydata]
    return np.array(labels)
# Script entry point: fit the softmax model on the Iris data and report
# the learned coefficients and the training-set confusion matrix.
if __name__ == '__main__':
    # NOTE(review): smdata[0] is used as the feature matrix and smdata[1] as
    # one-hot labels — confirm this layout in Iris_Data.
    features = smdata[0]
    # (n, 1) column of 1-based class labels; computed once and reused below.
    true_labels = transign(smdata[1])

    # LogisticRegression.fit expects a 1-D target, so flatten the column.
    sklr.fit(features, true_labels.ravel())
    # predict() returns a 1-D array; reshape to (n, 1) to match true_labels.
    predata = sklr.predict(features).reshape(-1, 1)

    # Append each class's intercept to its coefficient row, then transpose so
    # each column is one class and the last row holds the intercepts.
    print('系数为:\n', np.hstack((sklr.coef_, sklr.intercept_.reshape(-1, 1))).T)
    print('混淆矩阵:\n', confusion(true_labels, predata))